{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "--2018-06-19 09:29:30--  https://ikpublictutorial.blob.core.windows.net/deeplearningframeworks/DenseNet_121.caffemodel\n",
      "Resolving ikpublictutorial.blob.core.windows.net... 52.239.158.74\n",
      "Connecting to ikpublictutorial.blob.core.windows.net|52.239.158.74|:443... connected.\n",
      "HTTP request sent, awaiting response... 304 The condition specified using HTTP conditional header(s) is not met.\n",
      "File ‘DenseNet_121.caffemodel’ not modified on server. Omitting download.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "%%bash\n",
    "# Clone of https://github.com/shicai/DenseNet-Caffe\n",
    "wget -N https://ikpublictutorial.blob.core.windows.net/deeplearningframeworks/DenseNet_121.caffemodel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "MULTI_GPU = True  # TOGGLE THIS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/anaconda/envs/py35/lib/python3.5/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n",
      "  from ._conv import register_converters as _register_converters\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import sys\n",
    "import time\n",
    "import multiprocessing\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import chainer\n",
    "import chainer.functions as F\n",
    "import chainer.links as L\n",
    "import collections\n",
    "from chainer import optimizers, cuda, dataset, training\n",
    "from chainer.training import extensions, updaters, StandardUpdater\n",
    "from chainer.dataset import concat_examples\n",
    "from chainer.links.caffe import CaffeFunction\n",
    "from sklearn.metrics.ranking import roc_auc_score\n",
    "from sklearn.model_selection import train_test_split\n",
    "from PIL import Image\n",
    "from chainercv import transforms\n",
    "import random\n",
    "from common.utils import download_data_chextxray, get_imgloc_labels, get_train_valid_test_split\n",
    "from common.utils import compute_roc_auc, get_cuda_version, get_cudnn_version, get_gpu_name\n",
    "from common.params_dense import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Performance Improvement\n",
    "# 1. Auto-tune\n",
    "# This adds very little now .. not sure if True by default?\n",
    "chainer.cuda.set_max_workspace_size(512 * 1024 * 1024)\n",
    "chainer.global_config.autotune = True\n",
    "chainer.global_config.type_check = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "OS:  linux\n",
      "Python:  3.5.4 |Anaconda custom (64-bit)| (default, Nov 20 2017, 18:44:38) \n",
      "[GCC 7.2.0]\n",
      "Chainer:  4.1.0\n",
      "CuPy:  4.1.0\n",
      "Numpy:  1.14.1\n",
      "GPU:  ['Tesla V100-PCIE-16GB', 'Tesla V100-PCIE-16GB', 'Tesla V100-PCIE-16GB', 'Tesla V100-PCIE-16GB']\n",
      "CUDA Version 9.0.176\n",
      "CuDNN Version  7.0.5\n"
     ]
    }
   ],
   "source": [
    "print(\"OS: \", sys.platform)\n",
    "print(\"Python: \", sys.version)\n",
    "print(\"Chainer: \", chainer.__version__)\n",
    "print(\"CuPy: \", chainer.cuda.cupy.__version__)\n",
    "print(\"Numpy: \", np.__version__)\n",
    "print(\"GPU: \", get_gpu_name())\n",
    "print(get_cuda_version())\n",
    "print(\"CuDNN Version \", get_cudnn_version())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPUs:  24\n",
      "GPUs:  4\n"
     ]
    }
   ],
   "source": [
    "CPU_COUNT = multiprocessing.cpu_count()\n",
    "GPU_COUNT = len(get_gpu_name())\n",
    "print(\"CPUs: \", CPU_COUNT)\n",
    "print(\"GPUs: \", GPU_COUNT)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(0, 1, 2, 3)\n"
     ]
    }
   ],
   "source": [
    "DEVICES=tuple(list(range(GPU_COUNT)))\n",
    "print(DEVICES)\n",
    "if MULTI_GPU:\n",
    "    from cupy.cuda import nccl  # Test that nccl works for multi-gpu"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "chestxray/images chestxray/Data_Entry_2017.csv\n"
     ]
    }
   ],
   "source": [
    "# Model-params\n",
    "IMAGENET_RGB_MEAN_CAFFE = np.array([123.68, 116.78, 103.94], dtype=np.float32)\n",
    "IMAGENET_SCALE_FACTOR_CAFFE = 0.017\n",
    "# Paths\n",
    "CSV_DEST = \"chestxray\"\n",
    "IMAGE_FOLDER = os.path.join(CSV_DEST, \"images\")\n",
    "LABEL_FILE = os.path.join(CSV_DEST, \"Data_Entry_2017.csv\")\n",
    "print(IMAGE_FOLDER, LABEL_FILE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.0001\n"
     ]
    }
   ],
   "source": [
    "# Manually scale to multi-gpu\n",
    "if MULTI_GPU:\n",
    "    # Seems both batch and LR are scaled\n",
    "    pass\n",
    "print(LR)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Please make sure to download\n",
      "https://docs.microsoft.com/en-us/azure/storage/common/storage-use-azcopy-linux#download-and-install-azcopy\n",
      "Data already exists\n",
      "CPU times: user 543 ms, sys: 228 ms, total: 771 ms\n",
      "Wall time: 772 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# Download data\n",
    "print(\"Please make sure to download\")\n",
    "print(\"https://docs.microsoft.com/en-us/azure/storage/common/storage-use-azcopy-linux#download-and-install-azcopy\")\n",
    "download_data_chextxray(CSV_DEST)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "#####################################################################################################\n",
    "## Data Loading"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "class XrayData(dataset.DatasetMixin):\n",
    "    \"\"\"Image dataset yielding Caffe-normalised (c, h, w) float32 arrays with int32 labels.\n",
    "\n",
    "    Args:\n",
    "        patient_ids: Passed to ``get_imgloc_labels`` to select image locations and labels.\n",
    "        height: Target height for the random crop / resize.\n",
    "        width: Target width for the random crop / resize.\n",
    "        imagenet_mean: Per-channel mean subtracted from each image.\n",
    "        imagenet_scaling: Factor applied after the mean subtraction.\n",
    "        img_dir: Image folder passed to ``get_imgloc_labels``.\n",
    "        lbl_file: Label file passed to ``get_imgloc_labels``.\n",
    "        augmentation: Truthy enables random crop + random flip; ``None`` or ``False``\n",
    "            selects the deterministic resize instead.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, patient_ids, height=HEIGHT, width=WIDTH,\n",
    "                 imagenet_mean=IMAGENET_RGB_MEAN_CAFFE, imagenet_scaling = IMAGENET_SCALE_FACTOR_CAFFE,\n",
    "                 img_dir=IMAGE_FOLDER, lbl_file=LABEL_FILE, augmentation=None):\n",
    "\n",
    "        self.img_locs, self.labels = get_imgloc_labels(img_dir, lbl_file, patient_ids)\n",
    "        self.augmentation = augmentation\n",
    "        self.imagenet_mean = imagenet_mean\n",
    "        self.imagenet_scaling = imagenet_scaling\n",
    "        self.h = height\n",
    "        self.w = width\n",
    "        print(\"Loaded {} labels and {} images\".format(len(self.labels), len(self.img_locs)))\n",
    "\n",
    "    def __len__(self):\n",
    "        \"\"\"Number of images in the dataset.\"\"\"\n",
    "        return len(self.img_locs)\n",
    "\n",
    "    def get_example(self, idx):\n",
    "        \"\"\"Return ``(image, label)`` for index ``idx`` as float32 / int32 arrays.\"\"\"\n",
    "        im_file = self.img_locs[idx]\n",
    "        # Assumes the file decodes to a 3-channel image - TODO confirm against the data prep\n",
    "        im_rgb = Image.open(im_file)\n",
    "        im_rgb = self._apply_data_preprocessing(im_rgb)\n",
    "        label = self.labels[idx]\n",
    "        # Truthiness check: the previous `is not None` test also augmented\n",
    "        # datasets created with augmentation=False (validation and test)\n",
    "        if self.augmentation:\n",
    "            # Training: random crop to (h, w), then random flip\n",
    "            im_rgb = self._apply_data_augmentation(im_rgb)\n",
    "        else:\n",
    "            # Validation/Test: deterministic resize to (h, w)\n",
    "            im_rgb = transforms.resize(im_rgb, size=(self.h, self.w))\n",
    "        return np.array(im_rgb, dtype=np.float32), np.array(label, dtype=np.int32)\n",
    "\n",
    "    def _apply_data_preprocessing(self, rgb_im):\n",
    "        \"\"\"Convert a PIL image to a mean-subtracted, scaled float32 array in (c, h, w) order.\"\"\"\n",
    "        # Array (the float32 conversion makes a fresh copy, so the in-place ops below are safe)\n",
    "        im = np.asarray(rgb_im, dtype=np.float32)\n",
    "        # (h, w, c) to (c, h, w)\n",
    "        im = im.transpose(2, 0, 1)\n",
    "        # Caffe normalisation\n",
    "        im -= self.imagenet_mean[:, None, None]\n",
    "        im *= self.imagenet_scaling\n",
    "        return im\n",
    "\n",
    "    def _apply_data_augmentation(self, im):\n",
    "        \"\"\"Randomly crop ``im`` to (h, w) and apply chainercv's random flip.\"\"\"\n",
    "        im = transforms.random_crop(im, size=(self.h,self.w))\n",
    "        im = transforms.random_flip(im)\n",
    "        return im"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train:21563 valid:3080 test:6162\n"
     ]
    }
   ],
   "source": [
    "train_set, valid_set, test_set = get_train_valid_test_split(TOT_PATIENT_NUMBER)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loaded 87306 labels and 87306 images\n",
      "Loaded 7616 labels and 7616 images\n",
      "Loaded 17198 labels and 17198 images\n"
     ]
    }
   ],
   "source": [
    "train_dataset = XrayData(img_dir=IMAGE_FOLDER, lbl_file=LABEL_FILE, patient_ids=train_set, augmentation=True)\n",
    "valid_dataset = XrayData(img_dir=IMAGE_FOLDER, lbl_file=LABEL_FILE, patient_ids=valid_set, augmentation=False)\n",
    "test_dataset  = XrayData(img_dir=IMAGE_FOLDER, lbl_file=LABEL_FILE, patient_ids=test_set, augmentation=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "#####################################################################################################\n",
    "## Helper Functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "def truncate_bn(sym):\n",
    "    # Need to truncate batchnorm - eps\n",
    "    for layer in list(sym._children):\n",
    "        if \"bn\" in layer:\n",
    "            if sym.__dict__[layer].eps < 1e-5:\n",
    "                sym.__dict__[layer].eps = 1e-5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CaffeFunctionDenseNet121(CaffeFunction):\n",
    "    \"\"\"``CaffeFunction`` variant that folds each Scale layer into the preceding BN layer\n",
    "    and runs a forward pass that keeps only the activations later layers still need.\n",
    "    \"\"\"\n",
    "\n",
    "    # https://github.com/ilkarman/DeepLearningFrameworks/pull/103\n",
    "    # PR just combines those two layers (BN and Scale) into a single BN,\n",
    "    # and then it improves the computational time in forward pass with ~67%\n",
    "    def __init__(self, model_path):\n",
    "        \"\"\"Load the caffemodel at ``model_path`` and merge every Scale layer into its BN.\"\"\"\n",
    "        super(CaffeFunctionDenseNet121, self).__init__(model_path)\n",
    "        _prev_bn = None\n",
    "        _scale_layer_indices = []\n",
    "        for i, (func_name, bottom, top) in enumerate(self.layers):\n",
    "            if 'bn' in func_name:\n",
    "                _prev_bn = func_name\n",
    "            if 'scale' in func_name:\n",
    "                bn = getattr(self, _prev_bn)\n",
    "                scale = getattr(self, func_name)\n",
    "                # Only present when the Scale link was built with a bias term\n",
    "                scale_bias = getattr(scale, 'bias', None)\n",
    "                with bn.init_scope():\n",
    "                    bn.gamma = chainer.Parameter(scale.W.array)\n",
    "                    if scale_bias is not None:\n",
    "                        # Carry the learned shift over as well; deleting the Scale layer\n",
    "                        # without it would silently drop part of the pretrained weights\n",
    "                        bn.beta = chainer.Parameter(scale_bias.b.array)\n",
    "                delattr(self, func_name)\n",
    "                _scale_layer_indices.append(i)\n",
    "        # Remove from the back so the earlier indices stay valid\n",
    "        for i in sorted(_scale_layer_indices, reverse=True):\n",
    "            del self.forwards[self.layers[i][0]]\n",
    "            del self.layers[i]\n",
    "\n",
    "    # Standard function saves all variables so cannot use big batch\n",
    "    # This lets me run BATCH of 56 over 32 - still can't get to 64\n",
    "    # https://github.com/chainer/chainer/blob/master/chainer/links/caffe/caffe_function.py#L176\n",
    "    def __call__(self, inputs, **kwargs):\n",
    "        \"\"\"Run every layer in order, starting from ``inputs['data']``.\n",
    "\n",
    "        Extra keyword arguments (e.g. ``outputs``) are accepted but ignored.\n",
    "        Returns a 1-tuple holding the output of the last layer.\n",
    "        \"\"\"\n",
    "        # collections.Iterable was removed in Python 3.10; the ABC lives in collections.abc\n",
    "        from collections.abc import Iterable\n",
    "        variables = dict(inputs)\n",
    "        # Pools not to save\n",
    "        # These layers are not concatenated\n",
    "        _NOSAVE = set(['pool5', 'concat_5_16', 'concat_4_24', 'concat_3_12', 'concat_2_6'])\n",
    "        # Forward through all layers\n",
    "        for func_name, bottom, top in self.layers:\n",
    "\n",
    "            func = self.forwards[func_name]\n",
    "            # Concat ops require some previous layers that are saved\n",
    "            if \"concat\" in func_name:\n",
    "                input_vars = tuple([variables[bottom[0]], variables['data']])\n",
    "            else:\n",
    "                input_vars = tuple([variables['data']])\n",
    "            output_vars = func(*input_vars)\n",
    "            # Delete layers for concat once used\n",
    "            if \"concat\" in func_name:\n",
    "                del variables[bottom[0]]\n",
    "            if not isinstance(output_vars, Iterable):\n",
    "                output_vars = output_vars,\n",
    "            # Save to dict ('data' always holds the most recent activation)\n",
    "            variables['data'] = output_vars[0]\n",
    "            top = top[0]\n",
    "            # Save for concat\n",
    "            if (\"pool\" in top) and (top not in _NOSAVE):\n",
    "                variables[top] = output_vars[0]\n",
    "            elif (\"concat\" in top) and (top not in _NOSAVE):\n",
    "                variables[top] = output_vars[0]\n",
    "\n",
    "        return tuple([variables['data']])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "class DenseNet121(chainer.Chain):\n",
    "    \"\"\"Wraps the base network (up to the pool5 output) and adds a linear classifier.\"\"\"\n",
    "\n",
    "    def __init__(self, base_symbol, n_classes=14):\n",
    "        \"\"\"Register ``base_symbol`` and a 1024 -> ``n_classes`` linear layer as children.\"\"\"\n",
    "        super(DenseNet121, self).__init__()\n",
    "        classifier = L.Linear(1024, n_classes)\n",
    "        with self.init_scope():\n",
    "            self.base_symbol = base_symbol\n",
    "            self.fc = classifier\n",
    "\n",
    "    def __call__(self, x):\n",
    "        \"\"\"Feed ``x`` through the base network and return the classifier output.\"\"\"\n",
    "        feed = {'data': x}\n",
    "        wanted = ['pool5']\n",
    "        features = self.base_symbol(inputs=feed, outputs=wanted)\n",
    "        return self.fc(features[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_symbol(model_name='densenet121', multi_gpu=MULTI_GPU, classes=CLASSES):\n",
    "    \"\"\"Build the DenseNet121 classifier from ``DenseNet_121.caffemodel``.\n",
    "\n",
    "    Raises ``ValueError`` for any ``model_name`` other than 'densenet121'.\n",
    "    When ``multi_gpu`` is false the model is moved to GPU 0 before being returned;\n",
    "    otherwise it is returned without being moved to a device here.\n",
    "    \"\"\"\n",
    "    if model_name != 'densenet121':\n",
    "        raise ValueError(\"Unknown model-name\")\n",
    "    # Load base\n",
    "    backbone = CaffeFunctionDenseNet121(\"DenseNet_121.caffemodel\")\n",
    "    # Fix batch-norm\n",
    "    truncate_bn(backbone)\n",
    "    # Remove unused fc6: the attribute, its forward entry and the last layer record\n",
    "    delattr(backbone, 'fc6')\n",
    "    backbone.forwards.pop('fc6')\n",
    "    backbone.layers.pop()\n",
    "    model = DenseNet121(backbone, classes)\n",
    "    if multi_gpu:\n",
    "        return model\n",
    "    # CUDA\n",
    "    print(\"One GPU\")\n",
    "    chainer.cuda.get_device(0).use()  # Make a specified GPU current\n",
    "    model.to_gpu()\n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "def init_symbol(sym, lr=LR):\n",
    "    \"\"\"Return an Adam optimizer (alpha=``lr``, beta1=0.9, beta2=0.999) set up on ``sym``.\"\"\"\n",
    "    adam_kwargs = dict(alpha=lr, beta1=0.9, beta2=0.999)\n",
    "    adam = optimizers.Adam(**adam_kwargs)\n",
    "    adam.setup(sym)\n",
    "    return adam"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_roc_auc(data_gt, data_pd, full=True, classes=CLASSES):\n",
    "    \"\"\"Print the per-class ROC AUC scores and return their mean.\n",
    "\n",
    "    Column ``k`` of ``data_gt`` is scored against column ``k`` of ``data_pd`` for\n",
    "    ``k`` in ``range(classes)``. ``full`` is accepted but not used.\n",
    "    NOTE: this definition shadows the ``compute_roc_auc`` imported from ``common.utils``.\n",
    "    \"\"\"\n",
    "    per_class = [roc_auc_score(data_gt[:, k], data_pd[:, k]) for k in range(classes)]\n",
    "    print(\"Full AUC\", per_class)\n",
    "    return np.mean(per_class)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "def lossfun(x, t):\n",
    "    \"\"\"Sigmoid cross-entropy of predictions ``x`` against targets ``t``.\"\"\"\n",
    "    loss = F.sigmoid_cross_entropy(x, t)\n",
    "    return loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 732 ms, sys: 60.1 ms, total: 792 ms\n",
      "Wall time: 792 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# Load symbol\n",
    "predictor = get_symbol()\n",
    "chexnet_sym = L.Classifier(predictor, lossfun=lossfun)\n",
    "# Won't work for multi-class\n",
    "chexnet_sym.compute_accuracy = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 1 ms, sys: 119 µs, total: 1.12 ms\n",
      "Wall time: 1.13 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# Load optimiser\n",
    "optimizer = init_symbol(chexnet_sym)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Data-iterators\n",
    "# Seems optimal to not use all available processes but around 6\n",
    "if MULTI_GPU:\n",
    "    train_iters = [\n",
    "        chainer.iterators.MultiprocessIterator(\n",
    "            i, BATCHSIZE, n_prefetch=10, n_processes=2) \n",
    "        for i in chainer.datasets.split_dataset_n_random(train_dataset, len(DEVICES))]\n",
    "else:\n",
    "    train_iter = chainer.iterators.MultiprocessIterator(\n",
    "        train_dataset, BATCHSIZE, n_prefetch=10, n_processes=CPU_COUNT)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "# These can have a higher batch-size than train since no grads stored\n",
    "valid_iter = chainer.iterators.MultiprocessIterator(\n",
    "    valid_dataset, BATCHSIZE, repeat=False, shuffle=False, n_prefetch=10, n_processes=6)\n",
    "test_iter = chainer.iterators.MultiprocessIterator(\n",
    "    test_dataset, BATCHSIZE, repeat=False, shuffle=False, n_prefetch=10, n_processes=6)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/anaconda/envs/py35/lib/python3.5/site-packages/chainer/training/updaters/multiprocess_parallel_updater.py:138: UserWarning: optimizer.eps is changed to 4e-08 by MultiprocessParallelUpdater for new batch size.\n",
      "  format(optimizer.eps))\n"
     ]
    }
   ],
   "source": [
    "# MultiprocessParallelUpdater requires NCCL.\n",
    "# https://github.com/nvidia/nccl#build--run\n",
    "if MULTI_GPU:\n",
    "    updater = updaters.MultiprocessParallelUpdater(train_iters, optimizer, devices=DEVICES)\n",
    "else:\n",
    "    updater = StandardUpdater(train_iter, optimizer, device=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "val_interval = (1, 'epoch')\n",
    "trainer = training.Trainer(updater, stop_trigger=(EPOCHS, 'epoch'))\n",
    "trainer.extend(extensions.Evaluator(valid_iter, chexnet_sym, device=DEVICES[0]), trigger=val_interval)\n",
    "trainer.extend(extensions.LogReport(trigger=val_interval))\n",
    "trainer.extend(extensions.PrintReport(['epoch', 'iteration', 'main/loss', 'validation/main/loss']), \n",
    "               trigger=val_interval)\n",
    "trainer.extend(extensions.ProgressBar(update_interval=500))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch       iteration   main/loss   validation/main/loss\n",
      "\u001b[J1           342         0.204333    0.836164              \n",
      "\u001b[J     total [##############....................................] 29.32%\n",
      "this epoch [#######################...........................] 46.61%\n",
      "       500 iter, 1 epoch / 5 epochs\n",
      "       inf iters/sec. Estimated time to finish: 0:00:00.\n",
      "\u001b[4A\u001b[J2           683         0.154302    0.471195              \n",
      "\u001b[J     total [#############################.....................] 58.65%\n",
      "this epoch [##############################################....] 93.23%\n",
      "      1000 iter, 2 epoch / 5 epochs\n",
      "    3.5665 iters/sec. Estimated time to finish: 0:03:17.717039.\n",
      "\u001b[4A\u001b[J3           1024        0.149126    0.2357                \n",
      "\u001b[J4           1365        0.145293    0.207207              \n",
      "\u001b[J     total [###########################################.......] 87.97%\n",
      "this epoch [###################...............................] 39.84%\n",
      "      1500 iter, 4 epoch / 5 epochs\n",
      "    3.4371 iters/sec. Estimated time to finish: 0:00:59.688867.\n",
      "\u001b[4A\u001b[J5           1706        0.142412    0.193563              \n",
      "\u001b[JCPU times: user 8min 27s, sys: 21.1 s, total: 8min 48s\n",
      "Wall time: 8min 25s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# 1 GPU = 28min 49s\n",
    "# 4 GPU = 8min 25s\n",
    "trainer.run()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "#####################################################################################################\n",
    "## Test CheXNet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 27.2 s, sys: 973 ms, total: 28.1 s\n",
      "Wall time: 1min 19s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "y_truth = test_dataset.labels\n",
    "y_guess = []\n",
    "test_iter.reset()\n",
    "with chainer.using_config('train', False), chainer.using_config('enable_backprop', False):\n",
    "    for test_batch in test_iter:\n",
    "        # Data\n",
    "        x_test, y_test = concat_examples(test_batch, device=DEVICES[0])\n",
    "        # Prediction (need to apply sigmoid to turn into probability)\n",
    "        pred = cuda.to_cpu(F.sigmoid(predictor(x_test)).data)\n",
    "        # Collect results\n",
    "        y_guess.append(pred)           \n",
    "# Concatenate\n",
    "y_guess = np.concatenate(y_guess, axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Full AUC [0.615187835633225, 0.46654025243218605, 0.5944393444665546, 0.48702925510951905, 0.7003878224524447, 0.6053779649645415, 0.5190046463761602, 0.6796954792237812, 0.5327447269687331, 0.45756121386975235, 0.5552081164821705, 0.495231128949573, 0.46582532802981974, 0.5414502054803485]\n",
      "Test AUC: 0.5511\n"
     ]
    }
   ],
   "source": [
    "# 1 GPU AUC: 0.8022\n",
    "# 4 GPU AUC: 0.5511 ? Seems like a very big fall\n",
    "print(\"Test AUC: {0:.4f}\".format(compute_roc_auc(y_truth, y_guess)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "#####################################################################################################\n",
    "## Synthetic Data (Pure Training)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "87296\n"
     ]
    }
   ],
   "source": [
    "# Test on fake-data -> no IO lag\n",
    "batch_in_epoch = len(train_dataset.labels)//BATCHSIZE\n",
    "tot_num = batch_in_epoch * BATCHSIZE\n",
    "print(tot_num)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "class RandomData(dataset.DatasetMixin):\n",
    "    \"\"\"In-memory synthetic dataset: random float32 images with random 0/1 int32 labels.\n",
    "\n",
    "    Args:\n",
    "        tot_num: Number of examples to generate.\n",
    "        cls: Number of label columns per example.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, tot_num, cls = CLASSES):\n",
    "        self.fake_X = np.random.rand(tot_num, 3, 224, 224).astype(np.float32)\n",
    "        # rand() draws from [0, 1), so casting it to int32 made every label 0;\n",
    "        # draw the 0/1 labels directly instead\n",
    "        self.fake_y = np.random.randint(0, 2, size=(tot_num, cls)).astype(np.int32)\n",
    "\n",
    "    def __len__(self):\n",
    "        \"\"\"Number of synthetic examples.\"\"\"\n",
    "        return len(self.fake_y)\n",
    "\n",
    "    def get_example(self, idx):\n",
    "        \"\"\"Return the ``(image, label)`` pair stored at ``idx``.\"\"\"\n",
    "        return self.fake_X[idx], self.fake_y[idx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "optimizer = init_symbol(chexnet_sym)\n",
    "train_dataset = RandomData(tot_num)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/anaconda/envs/py35/lib/python3.5/site-packages/chainer/training/updaters/multiprocess_parallel_updater.py:138: UserWarning: optimizer.eps is changed to 4e-08 by MultiprocessParallelUpdater for new batch size.\n",
      "  format(optimizer.eps))\n"
     ]
    }
   ],
   "source": [
    "if MULTI_GPU:\n",
    "    train_iters = [\n",
    "        chainer.iterators.MultiprocessIterator(\n",
    "            i, BATCHSIZE, n_prefetch=10, n_processes=1) \n",
    "        for i in chainer.datasets.split_dataset_n_random(train_dataset, len(DEVICES))]\n",
    "    updater = updaters.MultiprocessParallelUpdater(train_iters, optimizer, devices=DEVICES)\n",
    "else:\n",
    "    train_iter = chainer.iterators.MultiprocessIterator(\n",
    "        train_dataset, BATCHSIZE, n_prefetch=10, n_processes=1)\n",
    "    updater = StandardUpdater(train_iter, optimizer, device=0)\n",
    "    \n",
    "trainer = training.Trainer(updater, stop_trigger=(EPOCHS, 'epoch'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 7min 12s, sys: 32 s, total: 7min 44s\n",
      "Wall time: 7min 23s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# 1 GPU - Synthetic data: 27min 19s\n",
    "# 4 GPU - Synthetic data: 7min 23s\n",
    "trainer.run()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:py35]",
   "language": "python",
   "name": "conda-env-py35-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
